#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Feb 17 11:44:57 2024

@author: jcfq2
"""

import os

# Working directory for this analysis. The chdir happens BEFORE the remaining
# imports on purpose: NOTE(review) presumably so the local
# JWSTSolarSystemPointing module in this directory is importable when the
# script is run interactively (cwd on sys.path) -- confirm before reordering.
jwst_dir='/Users/jcfq2/data/observations/jwst'

os.chdir(jwst_dir)

import matplotlib.pyplot as plt
import h3ppy
import glob
import numpy as np
from importlib import reload
import JWSTSolarSystemPointing as jssp
# Re-import so edits to JWSTSolarSystemPointing take effect in a long-lived
# interactive session without restarting the kernel.
reload(jssp)
#import jwst_uranus as jwstu

# In[0]

# Kernel files used by JWSTSolarSystemPointing for the pointing geometry
# (NOTE(review): presumably SPICE kernels; see jssp.load_kernels).
kernel_dir = '/Users/jcfq2/data/observations/jwst/kernels/'
jssp.load_kernels(kdir=kernel_dir)

# Set up a h3ppy object too, always useful
# (not used further in this section of the script).
h3p_model = h3ppy.h3p()



# %% can we slice with all wavelengths?  No we can't - a hard thing


from JWSTSolarSystemPointing import get_pixel_polygons as jssp_gpp

from pypolyclip import clip_multi


 
# Input cubes: one FITS file per dither position.
data_glob = "/Users/jcfq2/data/observations/jwst/5308_dither_separated/*.fits"
files = sorted(glob.glob(data_glob))
if not files:
    # Fail with a clear message instead of an IndexError on files[0] below.
    raise FileNotFoundError(f"no FITS files match {data_glob}")

# Open the first file only to size the spectral axis of the output maps.
geo = jssp.JWSTSolarSystemPointing(files[0])
geoim = geo.im

# Number of files to process; set lens to a smaller number to test on a subset.
lens = len(files)

# Grid resolution in cells per degree: scale=10 gives 0.1-degree cells,
# scale=1 gives 1-degree cells.
scale = 2

# pypolyclip grid (nx, ny): 450 degrees of longitude by 181 degrees of
# latitude. The mapping loop offsets longitude by +30 deg (and latitude by
# +90 deg), so x covers roughly -30..420 deg, leaving a buffer for footprints
# that straddle the 0/360 wrap. Polygon parts outside the grid are clipped
# away, so naxis may be larger than the data but must not be smaller.
naxis = (450*scale, 181*scale)

# Accumulation maps, indexed [x, y] or [x, y, wavelength]; the spectral
# length is taken from the first file.
# NOTE(review): these are (450*scale+1, 180*scale+1) rather than naxis; they
# must stay at least as large as any cell index clip_multi can return.
grid_shape = (int(450*scale+1), int(180*scale+1))
nwave = geoim[:, 0, 0].size
naxis_mapcount = np.zeros(grid_shape + (nwave,))
naxis_mapint = np.zeros(grid_shape)
naxis_mapspec = np.zeros(grid_shape + (nwave,))
naxis_maplos = np.zeros(grid_shape)

#  the two arrays the brightness goes into,
 
 
# Main loop: for each file, project every valid detector pixel onto the
# lon/lat grid and accumulate area-weighted values into the naxis_map* arrays.
for img_num in range(lens):
    file = files[img_num]
    print(img_num+1,'/',lens,': ',file)
    geo = jssp.JWSTSolarSystemPointing(file)
    cube = geo.full_fov()
    wave = geo.get_wavelength()
    # Backplane 10 is used only for the limb cut below.
    # NOTE(review): assumed to be the ray/tangent height; confirm against
    # JWSTSolarSystemPointing.full_fov.
    rayheight = cube[10, :, :]

    # Pixel footprints: polyedge[0][i] holds the vertices of pixel i (column 0
    # used as longitude, column 1 as latitude below) and polyedge[1][i] its
    # (row, col) position in the image. An alternative x axis is available
    # via jssp_gpp(geo, xaxis='localtime').
    polyedge = jssp_gpp(geo)

    # Brightness used for the intensity map: wavelength index 683 with no
    # background removed (cd is all zeros; the "methane?" placeholder).
    # Other choices used previously: geo.im[3487] or geo.im[214] ("heat").
    # When re-enabling those, write cd = geo.im[k]*0. -- not "= 0.", which
    # would overwrite geo.im[k] with zeros.
    ab = geo.im[683, :, :]
    cd = geo.im[683, :, :]*0.

    # Blank the image edges. ab is a basic slice, so (for an ndarray geo.im)
    # this also writes NaN into geo.im[683]; those pixels are skipped below
    # anyway because img_data is NaN there.
    ab[:, 0:4] = np.nan
    ab[:, -3:] = np.nan
    ab[0, :] = np.nan

    img_data = ab - cd
    geoim = geo.im
    print(img_num+1,'/',lens,': plotting done')

    # Remove pixels too close to the limb.
    img_data[rayheight > -500] = np.nan

    # Line-of-sight correction for emission from a shell 900..1400 above the
    # reference radius R1 (NOTE(review): presumably km, R1 = Saturn's polar
    # radius). Radii below are in units of R1. For a line of sight with
    # impact parameter RR, the near-side path through the shell is
    # sqrt(R3^2-RR^2) - sqrt(R2^2-RR^2); los_corr is the disc-centre
    # (RR = 0) path divided by that path.
    R1 = 54364.
    R2 = (R1+900.)/R1
    R3 = (R1+1400.)/R1
    # A surface point seen at emission angle e has impact parameter sin(e).
    # The previous sin(radians(90 - degrees(cos(e)))) applied np.degrees to a
    # cosine (unit error) and evaluated to cos(cos(e)), which never reaches
    # the RR_0 = 0 reference. NOTE(review): assumes cube[6] is the emission
    # angle in degrees.
    R_dist = np.sin(np.radians(cube[6, :, :]))
    RR = R_dist
    RR_0 = 0
    los_coor = np.sqrt(R3**2 - RR**2) - np.sqrt(R2**2 - RR**2)
    los_coor_0 = np.sqrt(R3**2 - RR_0**2) - np.sqrt(R2**2 - RR_0**2)
    los_corr = 1/(los_coor/los_coor_0)

    # Collect every 4-vertex footprint whose vertices and brightness are
    # valid. Rows are gathered in lists and stacked once; repeated np.vstack
    # inside the loop re-copied everything collected so far.
    px_rows, py_rows, pi_vals, geopi_rows, li_vals = [], [], [], [], []
    for verts, pix in zip(polyedge[0], polyedge[1]):
        if len(verts) != 4:
            continue
        a = np.array(verts)
        row, col = pix[0], pix[1]
        int_img_data = img_data[row, col]
        if np.isnan(a[0][0]) or np.isnan(int_img_data):
            continue
        # Grid coordinates: longitude offset by 30 deg, latitude by 90 deg,
        # both multiplied by scale (cells per degree).
        px_rows.append((a[:, 0]+30)*scale)
        py_rows.append((a[:, 1]+90)*scale)
        pi_vals.append(int_img_data)
        geopi_rows.append(geoim[:, row, col])
        li_vals.append(los_corr[row, col])

    # Without this guard an image with no valid pixels would re-clip the
    # previous file's px/py (double-counting it), or raise NameError on the
    # first file.
    if not px_rows:
        print(img_num+1,'/',lens,': no valid pixels, skipped')
        continue

    px = np.vstack(px_rows)
    py = np.vstack(py_rows)
    pi = np.asarray(pi_vals)
    geopi = np.vstack(geopi_rows)
    li = np.asarray(li_vals)
    print(img_num+1,'/',lens,': polyedge done')

    # Clip every footprint against the grid: xc/yc are the grid cells
    # touched, area the overlap with each cell, and slices[k] selects the
    # entries belonging to footprint k.
    xc, yc, area, slices = clip_multi(px, py, naxis)

    # Total clipped area of each footprint.
    A0 = np.asarray([np.sum(area[s]) for s in slices], dtype=float)

    nwave = geoim.shape[0]
    m_counts = np.zeros([len(area), nwave])
    m_int = np.zeros_like(area)
    m_los = np.zeros_like(area)
    m_spec = np.zeros([len(area), nwave])

    # Weighting: m_int and m_los use the raw overlap area. m_spec and
    # m_counts use the overlap area normalised by the footprint's total, so
    # the cells touched by one detector pixel sum to a count of 1. The
    # original note says this biases towards the smallest (most face-on)
    # pixels and still needs testing.
    for ixel, s in enumerate(slices):
        m_int[s] = pi[ixel]*area[s]
        m_los[s] = li[ixel]*area[s]
        frac = area[s]/np.sum(area[s])
        spec = geopi[ixel]
        valid = ~np.isnan(spec)
        # NaN spectral samples contribute neither signal nor count.
        m_spec[s] = np.where(valid, np.outer(frac, spec), 0.)
        m_counts[s] = np.where(valid, frac[:, None], 0.)

    print(img_num+1,'/',lens,': slices done')

    # Accumulate into the maps. np.add.at is unbuffered, so a grid cell hit
    # by several footprints in this image receives all of them; the previous
    # "a[xc,yc] = a[xc,yc] + v" form kept only the last write for repeated
    # cells.
    np.add.at(naxis_mapcount, (xc, yc), m_counts)
    np.add.at(naxis_mapint, (xc, yc), m_int)
    np.add.at(naxis_maplos, (xc, yc), m_los)
    np.add.at(naxis_mapspec, (xc, yc), m_spec)
 
# Save the accumulated maps. np.savez_compressed appends '.npz', so these are
# written as e.g. 'saturn_spectra_map_2.npy.npz' with the array under the key
# 'arr_0'; the names are kept unchanged so existing loaders keep working.
# The '_equal' variants were used for the equal-weighting runs.
np.savez_compressed('saturn_spectra_map_2.npy',naxis_mapspec)
np.savez_compressed('saturn_spectra_count_2.npy',naxis_mapcount)
np.savez_compressed('saturn_spectra_los_shell_2.npy',naxis_maplos)

# Longitude window shown: grid columns 45*scale .. 405*scale (360 degrees).
# NOTE(review): the mapping loop places longitude L at column (L+30)*scale,
# so this window starts at L = 15 deg; confirm whether it should start at
# 30*scale.
lon_window = slice(int(45*scale), int(360*scale+45*scale))

# Empty grid cells have zero count; let them become NaN quietly.
with np.errstate(divide='ignore', invalid='ignore'):
    # Count-weighted mean of the spectral cube at wavelength index 683.
    mean_683 = naxis_mapspec[lon_window, :, 683]/naxis_mapcount[lon_window, :, 683]
    # NOTE(review): naxis_maplos is accumulated with raw overlap areas while
    # naxis_mapcount uses per-pixel-normalised areas, so this ratio mixes
    # two weightings.
    los_ratio = naxis_maplos[lon_window, :]/naxis_mapcount[lon_window, :, 683]

# Each map gets its own figure; previously the second imshow drew over the
# first in the same axes.
plt.figure()
plt.imshow(np.rot90(mean_683)**0.1,vmin=3.5e2**0.1,vmax=8e3**0.1)

plt.figure()
plt.imshow(np.rot90(los_ratio),vmin=0,vmax=1)

plt.show()


